{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "原文代码作者：https://github.com/wzyonggege/statistical-learning-method"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 第8章 提升方法AdaBoost算法"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Boost\n",
    "\n",
    "“装袋”（bagging）和“提升”（boost）是构建组合模型的两种最主要的方法，所谓的组合模型是由多个基本模型构成的模型，组合模型的预测效果往往比任意一个基本模型的效果都要好。\n",
    "\n",
    "- 装袋：每个基本模型由从总体样本中随机抽样得到的不同数据集进行训练得到，通过重抽样得到不同训练数据集的过程称为装袋。\n",
    "\n",
    "- 提升：每个基本模型训练时的数据集采用不同权重，针对上一个基本模型分类错误的样本增加权重，使得新的模型重点关注误分类样本\n",
    "\n",
    "### AdaBoost\n",
    "\n",
    "AdaBoost是AdaptiveBoost的缩写，表明该算法是具有适应性的提升算法。\n",
    "\n",
    "算法的步骤如下：\n",
    "\n",
    "1）给每个训练样本$(x_1,x_2,\\ldots,x_N)$分配权重，初始权重$w_{1}$均为$1/N$。\n",
    "\n",
    "2）针对带有权值的样本进行训练，得到模型$G_m$（初始模型为$G_1$）。\n",
    "\n",
    "3）计算模型$G_m$的误分率$e_m=\\sum_{i=1}^Nw_iI(y_i\\not= G_m(x_i))$\n",
    "\n",
    "4）计算模型$G_m$的系数$\\alpha_m=\\frac{1}{2}\\log[\\frac{(1-e_m)}{e_m}]$\n",
    "\n",
    "5）根据误分率e和当前权重向量$w_m$更新权重向量$w_{m+1}$。\n",
    "\n",
    "6）计算组合模型$f(x)=\\sum_{m=1}^M\\alpha_mG_m(x_i)$的误分率。\n",
    "\n",
    "7）当组合模型的误分率或迭代次数低于一定阈值，停止迭代；否则，回到步骤2）"
   ]
  },
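  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Steps 3) to 5) can be sketched in a few lines of NumPy. This is a minimal sketch, not the implementation used later in this notebook: the function name `reweight` and the five-sample toy data are assumptions made for illustration, and labels are assumed to be in {-1, +1}.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "def reweight(w, y, pred):\n",
    "    # Hypothetical helper: one AdaBoost reweighting step for labels in {-1, +1}\n",
    "    e = np.sum(w[pred != y])               # step 3: weighted error rate e_m\n",
    "    alpha = 0.5 * np.log((1 - e) / e)      # step 4: coefficient alpha_m\n",
    "    w_new = w * np.exp(-alpha * y * pred)  # step 5: grow weights of misclassified samples\n",
    "    return e, alpha, w_new / w_new.sum()   # divide by Z_m so the weights sum to 1\n",
    "\n",
    "y = np.array([1, 1, -1, -1, 1])            # toy labels (illustrative)\n",
    "pred = np.array([1, 1, -1, -1, -1])        # a weak classifier that misses the last sample\n",
    "w = np.full(5, 0.2)                        # uniform initial weights 1/N\n",
    "e, alpha, w_next = reweight(w, y, pred)\n",
    "print(e, alpha, w_next)\n",
    "```\n",
    "\n",
    "With one of five equally weighted samples misclassified, $e_1=0.2$ and $\\alpha_1=\\frac{1}{2}\\log 4\\approx 0.693$. After normalization the misclassified sample carries weight 0.5 and each of the other four carries 0.125."
   ]
  },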
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import pandas as pd\n",
    "from sklearn.datasets import load_iris\n",
    "from sklearn.model_selection  import train_test_split\n",
    "import matplotlib.pyplot as plt\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 初始化数据\n",
    "def create_data():\n",
    "    # 加载鸢尾花数据集\n",
    "    iris = load_iris()\n",
    "    # 生成dataframe数据\n",
    "    df = pd.DataFrame(iris.data, columns=iris.feature_names)\n",
    "    df['label'] = iris.target\n",
    "    df.columns = ['sepal length', 'sepal width', 'petal length', 'petal width', 'label']\n",
    "    data = np.array(df.iloc[:100, [0, 1, -1]])\n",
    "    for i in range(len(data)):\n",
    "        if data[i,-1] == 0:\n",
    "            data[i,-1] = -1\n",
    "    return data[:,:2], data[:,-1]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "X, y = create_data()\n",
    "# 进行数据集交叉验证，20%测试数据，80%的训练数据\n",
    "X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<matplotlib.legend.Legend at 0x1a8cff60>"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXcAAAD8CAYAAACMwORRAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvOIA7rQAAGcZJREFUeJzt3X+MXWWdx/H3d4dZOiowoYwrzJQtWNMo0LUwgqQJccHdaq2lQba08VeVlV2DCwYXI4agNiZgSPDHkmj4kQWELXYrlsLyYxGW+CNSMwVs11YiCtoZ2GUotshaoB2++8e9Q2fu3Jl7n3vvmfs8z/28kqZzz316+n3O0S+353zOc83dERGRvPxZuwsQEZHWU3MXEcmQmruISIbU3EVEMqTmLiKSITV3EZEMqbmLiGRIzV1EJENq7iIiGTqk3oFm1gUMASPuvrzivbXA1cBIedO17n7DTPs76qijfP78+UHFioh0uq1btz7v7n21xtXd3IGLgZ3A4dO8/z13/0y9O5s/fz5DQ0MBf72IiJjZ7+oZV9dlGTMbAD4AzPhpXERE4lDvNfdvAJ8HXpthzIfMbJuZbTSzedUGmNkFZjZkZkOjo6OhtYqISJ1qNnczWw485+5bZxh2FzDf3RcBPwRurjbI3a9z90F3H+zrq3nJSEREGlTPNfclwAozWwbMAQ43s1vd/SPjA9x994Tx1wNfa22ZIiKts3//foaHh3n55ZfbXcq05syZw8DAAN3d3Q39+ZrN3d0vAy4DMLP3AP88sbGXtx/t7s+WX66gdONVRCRKw8PDHHbYYcyfPx8za3c5U7g7u3fvZnh4mOOOO66hfTScczezdWa2ovzyIjP7pZn9ArgIWNvofkVEivbyyy8zd+7cKBs7gJkxd+7cpv5lERKFxN0fBh4u/3zFhO2vf7oXyc2mx0a4+v4neGbPPo7p7eHSpQtZubi/3WVJk2Jt7OOarS+ouYt0mk2PjXDZHdvZt38MgJE9+7jsju0AavASNS0/IDKDq+9/4vXGPm7f/jGuvv+JNlUkubjvvvtYuHAhCxYs4Kqrrmr5/tXcRWbwzJ59QdtF6jE2NsaFF17Ivffey44dO1i/fj07duxo6d+hyzIiMzimt4eRKo38mN6eNlQj7dLq+y4///nPWbBgAccffzwAq1ev5s477+Qd73hHq0rWJ3eRmVy6dCE93V2TtvV0d3Hp0oVtqkhm2/h9l5E9+3AO3nfZ9NhIzT87nZGREebNO/gg/8DAACMjje+vGjV3kRmsXNzPleecRH9vDwb09/Zw5Tkn6WZqBynivou7T9nW6vSOLsuI1LBycb+aeQcr4r7LwMAAu3btev318PAwxxxzTMP7q0af3EVEZjDd/ZVm7ru8613v4te//jVPPfUUr776KrfffjsrVqyo/QcDqLmLiMygiPsuhxxyCNdeey1Lly7l7W9/O6tWreKEE05ottTJf0dL9yYikpnxS3Ktfkp52bJlLFu2rBUlVqXmLiJSQ4r3XXRZRkQkQ2ruIiIZUnMXEcmQmruISIbU3EVEMqTmLtnY9NgIS656iOO+8B8sueqhptb+ECnaJz/5Sd785jdz4oknFrJ/NXfJQhGLO4kUae3atdx3332F7V/NXbKgL9WQQm3bAF8/Eb7cW/p924amd3nGGWdw5JFHtqC46vQQk2RBX6ohhdm2Ae66CPaX/7e0d1fpNcCiVe2rqwZ9cpcsFLG4kwgAD6472NjH7d9X2h4xNXfJgr5UQwqzdzhseyR0WUayUNTiTiIcMVC6FFNte8TU3CUbKS7uJAk464rJ19wBuntK25uwZs0aHn74YZ5//nkGBgb4yle+wvnnn99ksQepuUvTWv3lwSJRGb9p+uC60qWYIwZKjb3Jm6nr169vQXHTU3OXpozny8djiOP5ckANXvKxaFXUyZhqdENVmqJ8uUic1NylKcqXS6rcvd0lzKjZ
+tTcpSnKl0uK5syZw+7du6Nt8O7O7t27mTNnTsP70DV3acqlSxdOuuYOypdL/AYGBhgeHmZ0dLTdpUxrzpw5DAw0HrdUc5emKF8uKeru7ua4445rdxmFqru5m1kXMASMuPvyivcOBW4BTgF2A+e5+9MtrFMipny5SHxCPrlfDOwEDq/y3vnAH9x9gZmtBr4GnNeC+kSSosy/xKKuG6pmNgB8ALhhmiFnAzeXf94InGVm1nx5IunQmvISk3rTMt8APg+8Ns37/cAuAHc/AOwF5jZdnUhClPmXmNRs7ma2HHjO3bfONKzKtikZIzO7wMyGzGwo5rvUIo1Q5l9iUs8n9yXACjN7GrgdONPMbq0YMwzMAzCzQ4AjgBcqd+Tu17n7oLsP9vX1NVW4SGyU+ZeY1Gzu7n6Zuw+4+3xgNfCQu3+kYthm4OPln88tj4nz6QCRgmhNeYlJwzl3M1sHDLn7ZuBG4Ltm9iSlT+yrW1SfSDKU+ZeYWLs+YA8ODvrQ0FBb/m4RkVSZ2VZ3H6w1Tk+oSrQu37Sd9Vt2MeZOlxlrTpvHV1ee1O6yRJKg5i5RunzTdm595Pevvx5zf/21GrxIbVoVUqK0fkuV76ycYbuITKbmLlEam+Ze0HTbRWQyNXeJUtc0q1dMt11EJlNzlyitOW1e0HYRmUw3VCVK4zdNlZYRaYxy7iIiCVHOXZry4et/xk9/c3B5oCVvPZLbPnV6GytqH63RLinSNXeZorKxA/z0Ny/w4et/1qaK2kdrtEuq1NxlisrGXmt7zrRGu6RKzV1kBlqjXVKl5i4yA63RLqlSc5cplrz1yKDtOdMa7ZIqNXeZ4rZPnT6lkXdqWmbl4n6uPOck+nt7MKC/t4crzzlJaRmJnnLuIiIJUc5dmlJUtjtkv8qXizROzV2mGM92j0cAx7PdQFPNNWS/RdUg0il0zV2mKCrbHbJf5ctFmqPmLlMUle0O2a/y5SLNUXOXKYrKdofsV/lykeaoucsURWW7Q/arfLlIc3RDVaYYv2HZ6qRKyH6LqkGkUyjnLiKSEOXcCxZDBju0hhhqFpHZoebegBgy2KE1xFCziMwe3VBtQAwZ7NAaYqhZRGaPmnsDYshgh9YQQ80iMnvU3BsQQwY7tIYYahaR2aPm3oAYMtihNcRQs4jMHt1QbUAMGezQGmKoWURmT82cu5nNAX4EHErpPwYb3f1LFWPWAlcD418Jf6273zDTfpVzFxEJ18qc+yvAme7+kpl1Az8xs3vd/ZGKcd9z9880UqzMjss3bWf9ll2MudNlxprT5vHVlSc1PTaW/HwsdYjEoGZz99JH+5fKL7vLv9rzWKs07PJN27n1kd+//nrM/fXXlU07ZGws+flY6hCJRV03VM2sy8weB54DHnD3LVWGfcjMtpnZRjOb19IqpWnrt+yqe3vI2Fjy87HUIRKLupq7u4+5+zuBAeBUMzuxYshdwHx3XwT8ELi52n7M7AIzGzKzodHR0WbqlkBj09xbqbY9ZGws+flY6hCJRVAU0t33AA8D76vYvtvdXym/vB44ZZo/f527D7r7YF9fXwPlSqO6zOreHjI2lvx8LHWIxKJmczezPjPrLf/cA7wX+FXFmKMnvFwB7GxlkdK8NadVv1JWbXvI2Fjy87HUIRKLetIyRwM3m1kXpf8YbHD3u81sHTDk7puBi8xsBXAAeAFYW1TB0pjxG6H1JGBCxsaSn4+lDpFYaD13EZGEaD33ghWVqQ7Jlxe575D5pXgskrNtAzy4DvYOwxEDcNYVsGhVu6uSiKm5N6CoTHVIvrzIfYfML8VjkZxtG+Cui2B/Ofmzd1fpNajBy7S0cFgDispUh+TLi9x3yPxSPBbJeXDdwcY+bv++0naRaai5N6CoTHVIvrzIfYfML8VjkZy9w2HbRVBzb0hRmeqQfHmR+w6ZX4rHIjlHDIRtF0HNvSFFZapD8uVF7jtkfikei+ScdQV0V/zHsruntF1kGrqh2oCiMtUh+fIi9x0y
vxSPRXLGb5oqLSMBlHMXEUmIcu4yRQzZdUmc8vbJUHPvEDFk1yVxytsnRTdUO0QM2XVJnPL2SVFz7xAxZNclccrbJ0XNvUPEkF2XxClvnxQ19w4RQ3ZdEqe8fVJ0Q7VDxJBdl8Qpb58U5dxFRBKinHtZUXntkP3Gsi65suuRyT0znvv8QrThWGTd3IvKa4fsN5Z1yZVdj0zumfHc5xeiTcci6xuqReW1Q/Yby7rkyq5HJvfMeO7zC9GmY5F1cy8qrx2y31jWJVd2PTK5Z8Zzn1+INh2LrJt7UXntkP3Gsi65suuRyT0znvv8QrTpWGTd3IvKa4fsN5Z1yZVdj0zumfHc5xeiTcci6xuqReW1Q/Yby7rkyq5HJvfMeO7zC9GmY6Gcu4hIQpRzL5jy8yKJuPsS2HoT+BhYF5yyFpZf0/x+I8/xq7k3QPl5kUTcfQkM3XjwtY8dfN1Mg08gx5/1DdWiKD8vkoitN4Vtr1cCOX419wYoPy+SCB8L216vBHL8au4NUH5eJBHWFba9Xgnk+NXcG6D8vEgiTlkbtr1eCeT4dUO1AcrPiyRi/KZpq9MyCeT4lXMXEUlIy3LuZjYH+BFwaHn8Rnf/UsWYQ4FbgFOA3cB57v50A3XXFJovT20N85Dseu7HotAccUj2uag6ipxf5BnspoTOLedjMYN6Lsu8Apzp7i+ZWTfwEzO7190fmTDmfOAP7r7AzFYDXwPOa3Wxofny1NYwD8mu534sCs0Rh2Sfi6qjyPklkMFuWOjccj4WNdS8oeolL5Vfdpd/VV7LORu4ufzzRuAss9bHNkLz5amtYR6SXc/9WBSaIw7JPhdVR5HzSyCD3bDQueV8LGqoKy1jZl1m9jjwHPCAu2+pGNIP7AJw9wPAXmBulf1cYGZDZjY0OjoaXGxovjy1NcxDsuu5H4tCc8Qh2eei6ihyfglksBsWOrecj0UNdTV3dx9z93cCA8CpZnZixZBqn9KndCR3v87dB919sK+vL7jY0Hx5amuYh2TXcz8WheaIQ7LPRdVR5PwSyGA3LHRuOR+LGoJy7u6+B3gYeF/FW8PAPAAzOwQ4AnihBfVNEpovT20N85Dseu7HotAccUj2uag6ipxfAhnshoXOLedjUUM9aZk+YL+77zGzHuC9lG6YTrQZ+DjwM+Bc4CEvIGMZmi9PbQ3zkOx67sei0BxxSPa5qDqKnF8CGeyGhc4t52NRQ82cu5ktonSztIvSJ/0N7r7OzNYBQ+6+uRyX/C6wmNIn9tXu/tuZ9qucu4hIuJbl3N19G6WmXbn9igk/vwz8XWiRIiJSjOyXH0juwR2ZHSEPtsTwEEyRD+6k9pBWDOcjAVk39+Qe3JHZEfJgSwwPwRT54E5qD2nFcD4SkfWqkMk9uCOzI+TBlhgeginywZ3UHtKK4XwkIuvmntyDOzI7Qh5sieEhmCIf3EntIa0Yzkcism7uyT24I7Mj5MGWGB6CKfLBndQe0orhfCQi6+ae3IM7MjtCHmyJ4SGYIh/cSe0hrRjORyKybu4rF/dz5Tkn0d/bgwH9vT1cec5Jupna6Ratgg9+C46YB1jp9w9+q/oNuZCxMdQbOr6o+aW23wzpyzpERBLSsoeYRDpeyBd7xCK1mmPJrsdSRwuouYvMJOSLPWKRWs2xZNdjqaNFsr7mLtK0kC/2iEVqNceSXY+ljhZRcxeZScgXe8QitZpjya7HUkeLqLmLzCTkiz1ikVrNsWTXY6mjRdTcRWYS8sUesUit5liy67HU0SJq7iIzWX4NDJ5/8FOvdZVex3hjclxqNceSXY+ljhZRzl1EJCHKucvsSTEbXFTNReXLUzzG0lZq7tKcFLPBRdVcVL48xWMsbadr7tKcFLPBRdVcVL48xWMsbafmLs1JMRtcVM1F5ctTPMbSdmru0pwUs8FF1VxUvjzFYyxtp+YuzUkxG1xUzUXly1M8xtJ2au7SnBSzwUXVXFS+PMVjLG2nnLuISELqzbnrk7vkY9sG+PqJ
8OXe0u/bNsz+fouqQSSQcu6Sh6Ky4CH7VR5dIqJP7pKHorLgIftVHl0iouYueSgqCx6yX+XRJSJq7pKHorLgIftVHl0iouYueSgqCx6yX+XRJSJq7pKHorLgIftVHl0iUjPnbmbzgFuAtwCvAde5+zcrxrwHuBN4qrzpDnef8S6Scu4iIuFauZ77AeBz7v6omR0GbDWzB9x9R8W4H7v78kaKlQiluH54SM0pzi8GOm7JqNnc3f1Z4Nnyz380s51AP1DZ3CUXKea1lUcvno5bUoKuuZvZfGAxsKXK26eb2S/M7F4zO6EFtUm7pJjXVh69eDpuSan7CVUzexPwfeCz7v5ixduPAn/p7i+Z2TJgE/C2Kvu4ALgA4Nhjj224aClYinlt5dGLp+OWlLo+uZtZN6XGfpu731H5vru/6O4vlX++B+g2s6OqjLvO3QfdfbCvr6/J0qUwKea1lUcvno5bUmo2dzMz4EZgp7tXXbvUzN5SHoeZnVre7+5WFiqzKMW8tvLoxdNxS0o9l2WWAB8FtpvZ4+VtXwSOBXD37wDnAp82swPAPmC1t2stYWne+M2xlFIRITWnOL8Y6LglReu5i4gkpJU5d4mVMseT3X0JbL2p9IXU1lX6ertmvwVJJFFq7qlS5niyuy+BoRsPvvaxg6/V4KUDaW2ZVClzPNnWm8K2i2ROzT1VyhxP5mNh20Uyp+aeKmWOJ7OusO0imVNzT5Uyx5OdsjZsu0jm1NxTpbXDJ1t+DQyef/CTunWVXutmqnQo5dxFRBKinHsDNj02wtX3P8Eze/ZxTG8Ply5dyMrF/e0uq3Vyz8XnPr8Y6BgnQ829bNNjI1x2x3b27S+lK0b27OOyO7YD5NHgc8/F5z6/GOgYJ0XX3Muuvv+J1xv7uH37x7j6/ifaVFGL5Z6Lz31+MdAxToqae9kze/YFbU9O7rn43OcXAx3jpKi5lx3T2xO0PTm55+Jzn18MdIyTouZedunShfR0T37gpae7i0uXLmxTRS2Wey4+9/nFQMc4KbqhWjZ+0zTbtEzua3HnPr8Y6BgnRTl3EZGE1Jtz12UZkRRs2wBfPxG+3Fv6fduGNPYtbaPLMiKxKzJfrux6tvTJXSR2RebLlV3Plpq7SOyKzJcru54tNXeR2BWZL1d2PVtq7iKxKzJfrux6ttTcRWJX5Nr9+l6AbCnnLiKSEOXcRUQ6mJq7iEiG1NxFRDKk5i4ikiE1dxGRDKm5i4hkSM1dRCRDau4iIhmq2dzNbJ6Z/ZeZ7TSzX5rZxVXGmJl9y8yeNLNtZnZyMeVKU7Rut0jHqGc99wPA59z9UTM7DNhqZg+4+44JY94PvK386zTg2+XfJRZat1uko9T85O7uz7r7o+Wf/wjsBCq/WPRs4BYveQToNbOjW16tNE7rdot0lKBr7mY2H1gMbKl4qx/YNeH1MFP/A4CZXWBmQ2Y2NDo6GlapNEfrdot0lLqbu5m9Cfg+8Fl3f7Hy7Sp/ZMqKZO5+nbsPuvtgX19fWKXSHK3bLdJR6mruZtZNqbHf5u53VBkyDMyb8HoAeKb58qRltG63SEepJy1jwI3ATne/Zpphm4GPlVMz7wb2uvuzLaxTmqV1u0U6Sj1pmSXAR4HtZvZ4edsXgWMB3P07wD3AMuBJ4E/AJ1pfqjRt0So1c5EOUbO5u/tPqH5NfeIYBy5sVVEiItIcPaEqIpIhNXcRkQypuYuIZEjNXUQkQ2ruIiIZUnMXEcmQmruISIasFFFvw19sNgr8ri1/eW1HAc+3u4gCaX7pynluoPnV4y/dvebiXG1r7jEzsyF3H2x3HUXR/NKV89xA82slXZYREcmQmruISIbU3Ku7rt0FFEzzS1fOcwPNr2V0zV1EJEP65C4ikqGObu5m1mVmj5nZ3VXeW2tmo2b2ePnX37ejxmaY2dNmtr1c/1CV983MvmVmT5rZNjM7uR11NqKOub3HzPZOOH9JfeWUmfWa2UYz+5WZ7TSz0yveT/bcQV3zS/b8
mdnCCXU/bmYvmtlnK8YUfv7q+bKOnF0M7AQOn+b977n7Z2axniL8tbtPl6t9P/C28q/TgG+Xf0/FTHMD+LG7L5+1alrrm8B97n6umf058IaK91M/d7XmB4meP3d/AngnlD5AAiPADyqGFX7+OvaTu5kNAB8Abmh3LW10NnCLlzwC9JrZ0e0uqtOZ2eHAGZS+3hJ3f9Xd91QMS/bc1Tm/XJwF/MbdKx/YLPz8dWxzB74BfB54bYYxHyr/k2mjmc2bYVysHPhPM9tqZhdUeb8f2DXh9XB5WwpqzQ3gdDP7hZnda2YnzGZxTToeGAX+tXzZ8AYze2PFmJTPXT3zg3TP30SrgfVVthd+/jqyuZvZcuA5d986w7C7gPnuvgj4IXDzrBTXWkvc/WRK/wS80MzOqHi/2tcnphKfqjW3Ryk9pv1XwL8Am2a7wCYcApwMfNvdFwP/B3yhYkzK566e+aV8/gAoX25aAfx7tberbGvp+evI5k7pS79XmNnTwO3AmWZ268QB7r7b3V8pv7weOGV2S2yeuz9T/v05Stf8Tq0YMgxM/BfJAPDM7FTXnFpzc/cX3f2l8s/3AN1mdtSsF9qYYWDY3beUX2+k1AwrxyR57qhjfomfv3HvBx519/+t8l7h568jm7u7X+buA+4+n9I/mx5y949MHFNx/WsFpRuvyTCzN5rZYeM/A38L/HfFsM3Ax8p37t8N7HX3Z2e51GD1zM3M3mJmVv75VEr/W98927U2wt3/B9hlZgvLm84CdlQMS/LcQX3zS/n8TbCG6pdkYBbOX6enZSYxs3XAkLtvBi4ysxXAAeAFYG07a2vAXwA/KP//4xDg39z9PjP7RwB3/w5wD7AMeBL4E/CJNtUaqp65nQt82swOAPuA1Z7WE3v/BNxW/qf9b4FPZHLuxtWaX9Lnz8zeAPwN8A8Tts3q+dMTqiIiGerIyzIiIrlTcxcRyZCau4hIhtTcRUQypOYuIpIhNXcRkQypuYuIZEjNXUQkQ/8PK3ebjfkeaxkAAAAASUVORK5CYII=\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "# 训练数据可视化\n",
    "plt.scatter(X[:50,0],X[:50,1], label='0')\n",
    "plt.scatter(X[50:,0],X[50:,1], label='1')\n",
    "plt.legend()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "----\n",
    "\n",
    "### AdaBoost in Python"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "class AdaBoost:\n",
    "    def __init__(self, n_estimators=50, learning_rate=1.0):\n",
    "        self.clf_num = n_estimators\n",
    "        self.learning_rate = learning_rate\n",
    "    \n",
    "    # 初始化参数\n",
    "    def init_args(self, datasets, labels):\n",
    "        \n",
    "        self.X = datasets\n",
    "        self.Y = labels\n",
    "        # M行N列\n",
    "        self.M, self.N = datasets.shape\n",
    "        \n",
    "        # 弱分类器数目和集合\n",
    "        self.clf_sets = []\n",
    "        \n",
    "        # 初始化权重向量 weights\n",
    "        self.weights = [1.0/self.M]*self.M\n",
    "        \n",
    "        # G(x)系数向量 alpha\n",
    "        self.alpha = []\n",
    "        \n",
    "    # 得到最终分类器G(x)\n",
    "    def _G(self, features, labels, weights):\n",
    "        # \n",
    "        m = len(features)\n",
    "        # 无穷大\n",
    "        error = 100000.0 \n",
    "        best_v = 0.0\n",
    "        # 单维features\n",
    "        features_min = min(features)\n",
    "        features_max = max(features)\n",
    "        n_step = (features_max - features_min + self.learning_rate) // self.learning_rate\n",
    "        # print('n_step:{}'.format(n_step))\n",
    "        direct, compare_array = None, None\n",
    "        for i in range(1, int(n_step)):\n",
    "            v = features_min + self.learning_rate * i\n",
    "            \n",
    "            if v not in features:\n",
    "                # 正向误分类标识\n",
    "                compare_array_positive = np.array([1 if features[k] > v else -1 for k in range(m)])\n",
    "                # 正向误分类权重之和\n",
    "                weight_error_positive = sum([weights[k] for k in range(m) if compare_array_positive[k] != labels[k]])\n",
    "                \n",
    "                # 反向误分类标识\n",
    "                compare_array_nagetive = np.array([-1 if features[k] > v else 1 for k in range(m)])\n",
    "                # 反向误分类权重之和\n",
    "                weight_error_nagetive = sum([weights[k] for k in range(m) if compare_array_nagetive[k] != labels[k]])\n",
    "                \n",
    "                # 取两者中最小的权重\n",
    "                if weight_error_positive < weight_error_nagetive:\n",
    "                    weight_error = weight_error_positive\n",
    "                    _compare_array = compare_array_positive\n",
    "                    direct = 'positive'\n",
    "                else:\n",
    "                    weight_error = weight_error_nagetive\n",
    "                    _compare_array = compare_array_nagetive\n",
    "                    direct = 'nagetive'\n",
    "                \n",
    "                # 取全局最小的误差率\n",
    "                if weight_error < error:\n",
    "                    error = weight_error\n",
    "                    compare_array = _compare_array\n",
    "                    best_v = v\n",
    "        \n",
    "        # 得到best_v：分类误差率最低时的阈值，\n",
    "        return best_v, direct, error, compare_array\n",
    "    \n",
    "    # 计算alpha\n",
    "    def _alpha(self, error):\n",
    "        return 0.5 * np.log((1-error)/error)\n",
    "    \n",
    "    # 规范化因子\n",
    "    def _Z(self, weights, a, clf):\n",
    "        return sum([weights[i]*np.exp(-1*a*self.Y[i]*clf[i]) for i in range(self.M)])\n",
    "        \n",
    "    # 权值更新\n",
    "    def _w(self, a, clf, Z):\n",
    "        for i in range(self.M):\n",
    "            self.weights[i] = self.weights[i]*np.exp(-1*a*self.Y[i]*clf[i])/ Z\n",
    "    \n",
    "    def _f(self, alpha, clf_sets):\n",
    "        pass\n",
    "    \n",
    "    def G(self, x, v, direct):\n",
    "        if direct == 'positive':\n",
    "            return 1 if x > v else -1 \n",
    "        else:\n",
    "            return -1 if x > v else 1 \n",
    "    \n",
    "    def fit(self, X, y):\n",
    "        self.init_args(X, y)\n",
    "        \n",
    "        for epoch in range(self.clf_num):\n",
    "            # 初始化最佳误差率\n",
    "            best_clf_error, best_v, clf_result = 100000, None, None\n",
    "            # 对特征值进行遍历\n",
    "            for j in range(self.N):\n",
    "                # 获取特征所在的列\n",
    "                features = self.X[:, j]\n",
    "                # 计算v：分类阈值、error：分类误差、direct：分类结果\n",
    "                v, direct, error, compare_array = self._G(features, self.Y, self.weights)\n",
    "                \n",
    "                # 获得分类误差率最低的各指标（分类阈值，分类误差，分类结果）\n",
    "                if error < best_clf_error:\n",
    "                    best_clf_error = error\n",
    "                    best_v = v\n",
    "                    final_direct = direct\n",
    "                    clf_result = compare_array\n",
    "                    axis = j\n",
    "                \n",
    "                if best_clf_error == 0:\n",
    "                    break\n",
    "            \n",
    "            # 获得alpha\n",
    "            a = self._alpha(best_clf_error)\n",
    "            self.alpha.append(a)\n",
    "            # 保存分类器（axis：序号，best_v：阈值，final_direct：分类结果）\n",
    "            self.clf_sets.append((axis, best_v, final_direct))\n",
    "            # 计算规范化因子\n",
    "            Z = self._Z(self.weights, a, clf_result)\n",
    "            # 更新权值\n",
    "            self._w(a, clf_result, Z)\n",
    "            \n",
    "            #打印每次的迭代更新过程\n",
    "            #print('classifier:{}/{} error:{:.3f} v:{} direct:{} a:{:.5f}'.format(epoch+1, self.clf_num, error, best_v, final_direct, a))\n",
    "            #print('weight:{}'.format(self.weights))\n",
    "            #print('\\n')\n",
    "    \n",
    "    # 预测值\n",
    "    def predict(self, feature):\n",
    "        result = 0.0\n",
    "        for i in range(len(self.clf_sets)):\n",
    "            axis, clf_v, direct = self.clf_sets[i]\n",
    "            f_input = feature[axis]\n",
    "            # 计算f(x)\n",
    "            result += self.alpha[i] * self.G(f_input, clf_v, direct)\n",
    "        # signmod\n",
    "        return 1 if result > 0 else -1\n",
    "    \n",
    "    # 计算分类正确率\n",
    "    def score(self, X_test, y_test):\n",
    "        right_count = 0\n",
    "        for i in range(len(X_test)):\n",
    "            feature = X_test[i]\n",
    "            if self.predict(feature) == y_test[i]:\n",
    "                right_count += 1\n",
    "        \n",
    "        return right_count / len(X_test)"
   ]
  },
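  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The threshold search in `_G` loops over the samples in Python. As a sketch of the same idea with NumPy broadcasting, the hypothetical function `best_stump` below evaluates all candidate thresholds at once. It assumes labels in {-1, +1}, and its candidate grid may differ slightly from the one `_G` builds, so it is an illustration rather than a drop-in replacement. The data are the ten points of Example 8.1.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "def best_stump(feature, labels, weights, step):\n",
    "    # Hypothetical helper: best threshold rule on one feature column\n",
    "    feature = np.asarray(feature, dtype=float)\n",
    "    labels = np.asarray(labels)\n",
    "    weights = np.asarray(weights, dtype=float)\n",
    "    # Candidate thresholds on a regular grid, skipping values equal to a sample\n",
    "    grid = np.arange(feature.min() + step, feature.max(), step)\n",
    "    grid = grid[~np.isin(grid, feature)]\n",
    "    # Row t holds the 'positive' rule at grid[t]: +1 above the threshold, else -1\n",
    "    pred = np.where(feature[None, :] > grid[:, None], 1, -1)\n",
    "    err_pos = ((pred != labels[None, :]) * weights[None, :]).sum(axis=1)\n",
    "    # The 'negative' rule flips every prediction, so it errs exactly where 'positive' is right\n",
    "    err_neg = weights.sum() - err_pos\n",
    "    errs = np.stack([err_pos, err_neg])\n",
    "    d, t = np.unravel_index(np.argmin(errs), errs.shape)\n",
    "    return grid[t], ('positive', 'negative')[d], errs[d, t]\n",
    "\n",
    "x = np.arange(10)\n",
    "y = np.array([1, 1, 1, -1, -1, -1, 1, 1, 1, -1])\n",
    "v, direct, err = best_stump(x, y, np.full(10, 0.1), step=0.5)\n",
    "print(v, direct, err)\n",
    "```\n",
    "\n",
    "On this data the smallest weighted error is 0.3, reached by a 'negative' rule. Thresholds 2.5 and 8.5 tie, so which of the two is returned depends on floating-point rounding."
   ]
  },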
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 例8.1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "X = np.arange(10).reshape(10, 1)\n",
    "y = np.array([1, 1, 1, -1, -1, -1, 1, 1, 1, -1])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 迭代更新3次，学习率是0.5\n",
    "clf = AdaBoost(n_estimators=3, learning_rate=0.5)\n",
    "clf.fit(X, y)"
   ]
  },
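  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The first round of Example 8.1 can be checked by hand. This is a sketch that assumes the first weak classifier is the threshold rule $G_1(x)=1$ for $x<2.5$ and $G_1(x)=-1$ otherwise, which misclassifies the three points $x=6,7,8$ under the uniform weights $0.1$:\n",
    "\n",
    "$$e_1 = 3\\times 0.1 = 0.3,\\qquad \\alpha_1=\\frac{1}{2}\\log\\frac{1-0.3}{0.3}\\approx 0.4236$$\n",
    "\n",
    "$$Z_1 = 0.7\\,e^{-\\alpha_1}+0.3\\,e^{\\alpha_1} = 2\\sqrt{0.3\\times 0.7}\\approx 0.9165$$\n",
    "\n",
    "$$w_{2,i}=\\frac{0.1\\,e^{-\\alpha_1}}{Z_1}\\approx 0.0714\\ \\text{(correctly classified)},\\qquad w_{2,i}=\\frac{0.1\\,e^{\\alpha_1}}{Z_1}\\approx 0.1667\\ \\text{(misclassified)}$$\n",
    "\n",
    "The three misclassified points now carry half of the total weight, so the next weak classifier is pushed to classify them correctly."
   ]
  },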
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "X, y = create_data()\n",
    "X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.33)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.696969696969697"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "clf = AdaBoost(n_estimators=10, learning_rate=0.2)\n",
    "clf.fit(X_train, y_train)\n",
    "clf.score(X_test, y_test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "average score:67.333%\n"
     ]
    }
   ],
   "source": [
    "result = []\n",
    "for i in range(1, 101):\n",
    "    X, y = create_data()\n",
    "    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.33)\n",
    "    clf = AdaBoost(n_estimators=100, learning_rate=0.2)\n",
    "    clf.fit(X_train, y_train)\n",
    "    r = clf.score(X_test, y_test)\n",
    "    result.append(r)\n",
    "\n",
    "print('average score:{:.3f}%'.format(sum(result)))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "-----\n",
    "# sklearn.ensemble.AdaBoostClassifier\n",
    "\n",
    "- algorithm：这个参数只有AdaBoostClassifier有。主要原因是scikit-learn实现了两种Adaboost分类算法，SAMME和SAMME.R。两者的主要区别是弱学习器权重的度量，SAMME使用了和我们的原理篇里二元分类Adaboost算法的扩展，即用对样本集分类效果作为弱学习器权重，而SAMME.R使用了对样本集分类的预测概率大小来作为弱学习器权重。由于SAMME.R使用了概率度量的连续值，迭代一般比SAMME快，因此AdaBoostClassifier的默认算法algorithm的值也是SAMME.R。我们一般使用默认的SAMME.R就够了，但是要注意的是使用了SAMME.R， 则弱分类学习器参数base_estimator必须限制使用支持概率预测的分类器。SAMME算法则没有这个限制。\n",
    "\n",
    "- n_estimators： AdaBoostClassifier和AdaBoostRegressor都有，就是我们的弱学习器的最大迭代次数，或者说最大的弱学习器的个数。一般来说n_estimators太小，容易欠拟合，n_estimators太大，又容易过拟合，一般选择一个适中的数值。默认是50。在实际调参的过程中，我们常常将n_estimators和下面介绍的参数learning_rate一起考虑。\n",
    "\n",
    "-  learning_rate:  AdaBoostClassifier和AdaBoostRegressor都有，即每个弱学习器的权重缩减系数ν\n",
    "\n",
    "- base_estimator：AdaBoostClassifier和AdaBoostRegressor都有，即我们的弱分类学习器或者弱回归学习器。理论上可以选择任何一个分类或者回归学习器，不过需要支持样本权重。我们常用的一般是CART决策树或者神经网络MLP。"
   ]
  },
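  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the role of the weak learner explicit, the sketch below passes a depth-1 decision tree (a decision stump) to AdaBoostClassifier as its first positional argument, which is the base estimator; positional use sidesteps the keyword rename between scikit-learn versions. The variable names (`X_demo`, `stump`, `model`, and so on) and the fixed `random_state=0` are assumptions made for illustration.\n",
    "\n",
    "```python\n",
    "from sklearn.datasets import load_iris\n",
    "from sklearn.ensemble import AdaBoostClassifier\n",
    "from sklearn.model_selection import train_test_split\n",
    "from sklearn.tree import DecisionTreeClassifier\n",
    "\n",
    "iris = load_iris()\n",
    "X_demo = iris.data[:100, :2]   # first two features of the first two classes\n",
    "y_demo = iris.target[:100]     # sklearn accepts the 0/1 labels directly\n",
    "X_tr, X_te, y_tr, y_te = train_test_split(X_demo, y_demo, test_size=0.33, random_state=0)\n",
    "\n",
    "stump = DecisionTreeClassifier(max_depth=1)  # weak learner: one threshold on one feature\n",
    "model = AdaBoostClassifier(stump, n_estimators=100, learning_rate=0.5)\n",
    "model.fit(X_tr, y_tr)\n",
    "acc = model.score(X_te, y_te)\n",
    "print(acc)\n",
    "```\n",
    "\n",
    "A decision tree supports both sample weights and probability prediction, so it satisfies the requirements listed above for either algorithm setting."
   ]
  },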
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "AdaBoostClassifier(algorithm='SAMME.R', base_estimator=None,\n",
       "          learning_rate=0.5, n_estimators=100, random_state=None)"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from sklearn.ensemble import AdaBoostClassifier\n",
    "clf = AdaBoostClassifier(n_estimators=100, learning_rate=0.5)\n",
    "clf.fit(X_train, y_train)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.9090909090909091"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "clf.score(X_test, y_test)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.2"
  },
  "toc": {
   "base_numbering": 1,
   "nav_menu": {},
   "number_sections": false,
   "sideBar": true,
   "skip_h1_title": false,
   "title_cell": "Table of Contents",
   "title_sidebar": "Contents",
   "toc_cell": false,
   "toc_position": {},
   "toc_section_display": true,
   "toc_window_display": false
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
